package easykafka

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"gitee.com/Cookie_XiaoD/easykafka/spec"
	"github.com/Shopify/sarama"
	"strings"
)

// Producer is a Kafka message producer.
// It wraps one sarama SyncProducer and one sarama AsyncProducer, both created
// by NewProducer from the same broker list and configuration.
type Producer struct {
	// brokers is the list of Kafka broker addresses (host:port).
	brokers []string
	// dataEncoder serializes the user-supplied data into the message value.
	dataEncoder DataEncoder
	// errorHandler, if non-nil, is invoked for each error read from the
	// async producer's Errors channel.
	errorHandler AsyncProduceErrorHandler
	// syncProducer is the synchronous producer used by SyncProduce.
	syncProducer sarama.SyncProducer
	// asyncProducer is the asynchronous producer used by AsyncProduce.
	asyncProducer sarama.AsyncProducer
}

// NewProducer creates a message producer backed by both a sarama SyncProducer
// and a sarama AsyncProducer that share the same configuration.
//
// brokers is a comma-separated list of broker addresses, for example
// "127.0.0.1:9092,127.0.0.1:9093". Whitespace around each address is ignored,
// as are empty entries (such as those produced by a trailing comma).
//
// Defaults, which can be overridden via options: ack mode spec.WaitLeader and
// a JSON data encoder. A nil encoder supplied through WithProducerDataEncoder
// falls back to the JSON default, and nil options are skipped.
//
// An error is returned if no usable broker address is given, if the ack mode
// is not one of spec.WaitNone / spec.WaitLeader / spec.WaitAll, or if either
// underlying sarama producer cannot be created.
func NewProducer(brokers string, options ...ProducerOption) (*Producer, error) {
	if strings.TrimSpace(brokers) == "" {
		return nil, errors.New("未指定brokers")
	}

	// strings.Split never returns an empty slice for non-empty input, so each
	// entry is trimmed and empty entries are dropped. Otherwise " host:port" or
	// "" (from a trailing comma) would be handed to sarama as a broker.
	parts := strings.Split(brokers, ",")
	addr := make([]string, 0, len(parts))
	for _, part := range parts {
		if part = strings.TrimSpace(part); part != "" {
			addr = append(addr, part)
		}
	}
	if len(addr) == 0 {
		return nil, errors.New("brokers格式不正确，应该如 127.0.0.1:9092,127.0.0.1:9093")
	}

	defaultEncoder := func(data interface{}) ([]byte, error) {
		return json.Marshal(data)
	}
	pc := &producerConfig{
		ackMode:     spec.WaitLeader,
		dataEncoder: defaultEncoder,
	}
	for _, opt := range options {
		if opt != nil {
			opt(pc)
		}
	}
	// A nil encoder would panic on the first produce call; use the default.
	if pc.dataEncoder == nil {
		pc.dataEncoder = defaultEncoder
	}
	// Validate here so that an invalid option yields an error instead of
	// reaching the panic in convertAckMode.
	switch pc.ackMode {
	case spec.WaitNone, spec.WaitLeader, spec.WaitAll:
	default:
		return nil, fmt.Errorf("ackMode值不正确:%v", pc.ackMode)
	}

	cfg := sarama.NewConfig()

	if pc.saslInfo != nil {
		cfg.Net.SASL.User = pc.saslInfo.UserName
		cfg.Net.SASL.Password = pc.saslInfo.Password
		cfg.Net.SASL.Enable = true
		cfg.Net.SASL.Mechanism = sarama.SASLTypePlaintext
	}

	cfg.Producer.RequiredAcks = convertAckMode(pc.ackMode)
	// A SyncProducer is itself built on an AsyncProducer, so both Return.Errors
	// and Return.Successes must be true for it to report results. The consequence
	// is that the AsyncProducer created from the same config must have both its
	// Successes and Errors channels drained (see the goroutines below).
	cfg.Producer.Return.Errors = true
	cfg.Producer.Return.Successes = true

	sp, err := sarama.NewSyncProducer(addr, cfg)
	if err != nil {
		return nil, err
	}
	ap, err := sarama.NewAsyncProducer(addr, cfg)
	if err != nil {
		// Best-effort cleanup; the creation error is the one worth reporting.
		_ = sp.Close()
		return nil, err
	}

	p := &Producer{
		brokers:       addr,
		dataEncoder:   pc.dataEncoder,
		errorHandler:  pc.errorHandler,
		syncProducer:  sp,
		asyncProducer: ap,
	}

	go func() {
		// Return.Successes is true, so successes must be consumed or the
		// async producer will eventually block. The loop ends when sarama
		// closes the channel.
		for range ap.Successes() {
		}
	}()

	go func() {
		// Return.Errors is true, so errors must be consumed whether or not an
		// error handler was configured.
		for e := range ap.Errors() {
			if p.errorHandler == nil {
				continue
			}
			p.errorHandler(&AsyncProduceError{
				Err: e.Err,
				Msg: e.Msg,
			})
		}
	}()

	return p, nil
}

// SyncProduce encodes data and sends it to topic under key, blocking until
// the sync producer returns (acknowledgement level follows the configured
// ack mode). On success it reports the byte length of the encoded value.
func (p *Producer) SyncProduce(topic, key string, data interface{}) (n int, err error) {
	msg, buildErr := p.getMsg(topic, key, data)
	if buildErr != nil {
		return 0, buildErr
	}
	if _, _, sendErr := p.syncProducer.SendMessage(msg); sendErr != nil {
		return 0, sendErr
	}
	n = msg.Value.Length()
	return n, nil
}

// AsyncProduce encodes data and hands the message to the async producer's
// input channel. It returns as soon as the message is accepted by that channel,
// or with ctx.Err() if ctx is done first. A nil error means only that the
// message was enqueued. Delivery failures surface later through the error
// handler configured with WithProducerErrorHandler.
func (p *Producer) AsyncProduce(ctx context.Context, topic, key string, data interface{}) (n int, err error) {
	msg, buildErr := p.getMsg(topic, key, data)
	if buildErr != nil {
		return 0, buildErr
	}
	input := p.asyncProducer.Input()
	select {
	case <-ctx.Done():
		return 0, ctx.Err()
	case input <- msg:
	}
	return msg.Value.Length(), nil
}

// Close shuts down both underlying sarama producers. Both are always closed,
// even if closing the first one fails.
//
// Because Producer.Return.Errors is enabled, sarama's AsyncProducer.Close
// drains the Errors channel itself and returns any errors it collects.
// Errors drained by Close are not delivered to the error handler, so they are
// returned here rather than discarded. When both closes fail, the sync error
// is wrapped (usable with errors.Is/As) and the async error is included in
// the message.
func (p *Producer) Close() (err error) {
	syncErr := p.syncProducer.Close()
	asyncErr := p.asyncProducer.Close()
	switch {
	case syncErr != nil && asyncErr != nil:
		return fmt.Errorf("关闭同步生产者失败:%w; 关闭异步生产者失败:%v", syncErr, asyncErr)
	case syncErr != nil:
		return fmt.Errorf("关闭同步生产者失败:%w", syncErr)
	case asyncErr != nil:
		return fmt.Errorf("关闭异步生产者失败:%w", asyncErr)
	}
	return nil
}

// getMsg validates the inputs and builds a sarama message.
// topic and key must not be blank, and data must not be a nil interface.
// data is serialized with the producer's dataEncoder and becomes the message
// value. Encoder failures are wrapped with %w.
func (p *Producer) getMsg(topic, key string, data interface{}) (*sarama.ProducerMessage, error) {
	switch {
	case strings.TrimSpace(topic) == "":
		return nil, errors.New("topic无效")
	case strings.TrimSpace(key) == "":
		return nil, errors.New("key无效")
	case data == nil:
		return nil, errors.New("data无效")
	}

	payload, encErr := p.dataEncoder(data)
	if encErr != nil {
		return nil, fmt.Errorf("data序列化失败:%w", encErr)
	}

	return &sarama.ProducerMessage{
		Topic: topic,
		Key:   sarama.StringEncoder(key),
		Value: sarama.ByteEncoder(payload),
	}, nil
}

// producerConfig collects the settings applied by ProducerOption values
// before NewProducer builds the sarama configuration.
type producerConfig struct {
	// dataEncoder serializes message data; NewProducer defaults it to JSON.
	dataEncoder DataEncoder
	// errorHandler receives async produce errors; nil means errors are drained
	// and ignored.
	errorHandler AsyncProduceErrorHandler
	// ackMode is the required acknowledgement level; defaults to spec.WaitLeader.
	ackMode spec.AckMode
	// saslInfo, when non-nil, enables SASL/PLAIN authentication.
	saslInfo *spec.SASLInfo
}

// ProducerOption is a functional option that mutates a producerConfig;
// pass any number of them to NewProducer.
type ProducerOption func(*producerConfig)

// WithProducerAckMode sets the acknowledgement mode required for produced
// messages. Without this option the producer waits for the leader's ack.
func WithProducerAckMode(mode spec.AckMode) ProducerOption {
	apply := func(cfg *producerConfig) {
		cfg.ackMode = mode
	}
	return apply
}

// WithProducerSASL supplies SASL credentials. When set, NewProducer enables
// SASL with the PLAIN mechanism using these credentials.
func WithProducerSASL(sasl spec.SASLInfo) ProducerOption {
	info := sasl
	return ProducerOption(func(cfg *producerConfig) {
		cfg.saslInfo = &info
	})
}

// WithProducerErrorHandler installs a callback that is invoked for every
// error reported by the asynchronous producer.
func WithProducerErrorHandler(handler AsyncProduceErrorHandler) ProducerOption {
	apply := func(cfg *producerConfig) {
		cfg.errorHandler = handler
	}
	return apply
}

// WithProducerDataEncoder replaces the encoder used to serialize message
// data. Without this option data is encoded as JSON.
func WithProducerDataEncoder(e DataEncoder) ProducerOption {
	return ProducerOption(func(cfg *producerConfig) {
		cfg.dataEncoder = e
	})
}

// AsyncProduceError describes a failure reported by the asynchronous producer.
type AsyncProduceError struct {
	// Err is the error sarama reported for the message.
	Err error
	// Msg is the message that failed to be produced.
	Msg *sarama.ProducerMessage
}

// AsyncProduceErrorHandler is called for each error read from the async
// producer's Errors channel. It runs on a single background goroutine started
// by NewProducer, so a slow handler delays processing of subsequent errors.
type AsyncProduceErrorHandler func(err *AsyncProduceError)

// DataEncoder serializes the user-supplied data into the bytes used as the
// Kafka message value.
type DataEncoder func(data interface{}) ([]byte, error)

// convertAckMode maps the package's AckMode onto sarama's RequiredAcks:
// WaitNone -> NoResponse, WaitLeader -> WaitForLocal, WaitAll -> WaitForAll.
// It panics on any other value.
func convertAckMode(mode spec.AckMode) sarama.RequiredAcks {
	acks, known := map[spec.AckMode]sarama.RequiredAcks{
		spec.WaitNone:   sarama.NoResponse,
		spec.WaitLeader: sarama.WaitForLocal,
		spec.WaitAll:    sarama.WaitForAll,
	}[mode]
	if !known {
		panic("mode值不正确")
	}
	return acks
}
